{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp callback.hook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model hooks\n",
    "\n",
    "> Callback and helper function to add hooks in models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.test_utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What are hooks?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hooks are functions you can attach to a particular layer in your model and that will be executed in the forward pass (for forward hooks) or backward pass (for backward hooks). Here we begin with an introduction around hooks, but you should jump to `HookCallback` if you quickly want to implement one (and read the following example `ActivationStats`).\n",
    "\n",
    "Forward hooks are functions that take three arguments: the layer it's applied to, the input of that layer and the output of that layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(in_features=5, out_features=3, bias=True) (tensor([[-0.7487, -0.5879, -0.5319, -0.3080, -1.0268],\n",
      "        [-0.3987, -0.9263, -0.9486, -0.4070,  0.9790],\n",
      "        [-0.3519, -0.6336, -2.2617,  0.6636,  0.2626],\n",
      "        [ 1.0500, -1.3234,  0.1581,  0.4866, -1.2595]]),) tensor([[ 0.0179,  0.4064, -0.1843],\n",
      "        [-0.2548, -0.1191, -0.3833],\n",
      "        [-0.4946,  0.9474, -1.2664],\n",
      "        [-0.3220,  1.2831, -0.1166]], grad_fn=<AddmmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "tst_model = nn.Linear(5,3)\n",
    "def example_forward_hook(m,i,o): print(m,i,o)\n",
    "    \n",
    "x = torch.randn(4,5)\n",
    "hook = tst_model.register_forward_hook(example_forward_hook)\n",
    "y = tst_model(x)\n",
    "hook.remove()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Backward hooks are functions that take three arguments: the layer it's applied to, the gradients of the loss with respect to the input, and the gradients with respect to the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(in_features=5, out_features=3, bias=True) (tensor([-0.1618,  0.2840, -0.1865]), None, tensor([[-0.0768,  0.1065,  0.0312],\n",
      "        [ 0.0761, -0.3147,  0.0291],\n",
      "        [ 0.1030, -0.1513,  0.1539],\n",
      "        [ 0.0551, -0.0978,  0.0182],\n",
      "        [-0.0336, -0.1903, -0.0547]])) (tensor([[-0.0047,  0.1571, -0.0039],\n",
      "        [-0.0800,  0.0675, -0.0910],\n",
      "        [-0.0473,  0.0637,  0.0063],\n",
      "        [-0.0299, -0.0042, -0.0979]]),)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/wgilliam/miniconda3/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/module.py:1025: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.\n",
      "  warnings.warn(\"Using a non-full backward hook when the forward contains multiple autograd Nodes \"\n"
     ]
    }
   ],
   "source": [
    "def example_backward_hook(m,gi,go): print(m,gi,go)\n",
    "hook = tst_model.register_backward_hook(example_backward_hook)\n",
    "\n",
    "x = torch.randn(4,5)\n",
    "y = tst_model(x)\n",
    "loss = y.pow(2).mean()\n",
    "loss.backward()\n",
    "hook.remove()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hooks can change the input/output of a layer, or the gradients, print values or shapes. If you want to store something related to these inputs/outputs, it's best to have your hook associated to a class so that it can put it in the state of an instance of that class."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hook -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@docs\n",
    "class Hook():\n",
    "    \"Create a hook on `m` with `hook_func`.\"\n",
    "    def __init__(self, m, hook_func, is_forward=True, detach=True, cpu=False, gather=False):\n",
    "        store_attr('hook_func,detach,cpu,gather')\n",
    "        # pick the registration method matching the pass we want to observe\n",
    "        register = m.register_forward_hook if is_forward else m.register_backward_hook\n",
    "        self.hook = register(self.hook_fn)\n",
    "        self.stored,self.removed = None,False\n",
    "\n",
    "    def hook_fn(self, module, input, output):\n",
    "        \"Applies `hook_func` to `module`, `input`, `output`.\"\n",
    "        if self.detach:\n",
    "            detach_kwargs = dict(cpu=self.cpu, gather=self.gather)\n",
    "            input  = to_detach(input,  **detach_kwargs)\n",
    "            output = to_detach(output, **detach_kwargs)\n",
    "        # keep whatever `hook_func` returns so it can be read back from `stored`\n",
    "        self.stored = self.hook_func(module, input, output)\n",
    "\n",
    "    def remove(self):\n",
    "        \"Remove the hook from the model.\"\n",
    "        if self.removed: return\n",
    "        self.hook.remove()\n",
    "        self.removed=True\n",
    "\n",
    "    def __enter__(self, *args): return self\n",
    "    def __exit__(self, *args): self.remove()\n",
    "\n",
    "    _docs = dict(__enter__=\"Register the hook\",\n",
    "                 __exit__=\"Remove the hook\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This will be called during the forward pass if `is_forward=True`, the backward pass otherwise, and will optionally `detach`, `gather` and put on the `cpu` the (gradient of the) input/output of the model before passing them to `hook_func`. The result of `hook_func` will be stored in the `stored` attribute of the `Hook`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst_model = nn.Linear(5,3)\n",
    "hook = Hook(tst_model, lambda m,i,o: o)\n",
    "y = tst_model(x)\n",
    "test_eq(hook.stored, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Hook.hook_fn\" class=\"doc_header\"><code>Hook.hook_fn</code><a href=\"__main__.py#L11\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Hook.hook_fn</code>(**`module`**, **`input`**, **`output`**)\n",
       "\n",
       "Applies `hook_func` to `module`, `input`, `output`."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Hook.hook_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Hook.remove\" class=\"doc_header\"><code>Hook.remove</code><a href=\"__main__.py#L17\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Hook.remove</code>()\n",
       "\n",
       "Remove the hook from the model."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Hook.remove)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Note: It's important to properly remove your hooks from your model when you're done to avoid them being called again next time your model is applied to some inputs, and to free the memory that goes with their state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst_model = nn.Linear(5,10)\n",
    "x = torch.randn(4,5)\n",
    "y = tst_model(x)\n",
    "hook = Hook(tst_model, example_forward_hook)\n",
    "test_stdout(lambda: tst_model(x), f\"{tst_model} ({x},) {y.detach()}\")\n",
    "hook.remove()\n",
    "test_stdout(lambda: tst_model(x), \"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Context Manager"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since it's very important to remove your `Hook` even if your code is interrupted by some bug, `Hook` can be used as a context manager."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Hook.__enter__\" class=\"doc_header\"><code>Hook.__enter__</code><a href=\"__main__.py#L23\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Hook.__enter__</code>(**\\*`args`**)\n",
       "\n",
       "Register the hook"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Hook.__enter__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Hook.__exit__\" class=\"doc_header\"><code>Hook.__exit__</code><a href=\"__main__.py#L24\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Hook.__exit__</code>(**\\*`args`**)\n",
       "\n",
       "Remove the hook"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Hook.__exit__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst_model = nn.Linear(5,10)\n",
    "x = torch.randn(4,5)\n",
    "y = tst_model(x)\n",
    "with Hook(tst_model, example_forward_hook) as h:\n",
    "    test_stdout(lambda: tst_model(x), f\"{tst_model} ({x},) {y.detach()}\")\n",
    "test_stdout(lambda: tst_model(x), \"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _hook_inner(m,i,o):\n",
    "    \"Return `o` as is when it is a `Tensor` or listy, otherwise turn it into a list\"\n",
    "    if isinstance(o,Tensor) or is_listy(o): return o\n",
    "    return list(o)\n",
    "\n",
    "def hook_output(module, detach=True, cpu=False, grad=False):\n",
    "    \"Return a `Hook` that stores activations of `module` in `self.stored`\"\n",
    "    # a backward hook is registered when gradients are requested, a forward hook otherwise\n",
    "    return Hook(module, _hook_inner, is_forward=not grad, detach=detach, cpu=cpu)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The activations stored are the gradients if `grad=True`, otherwise the output of `module`. If `detach=True` they are detached from their history, and if `cpu=True`, they're put on the CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst_model = nn.Linear(5,10)\n",
    "x = torch.randn(4,5)\n",
    "with hook_output(tst_model) as h:\n",
    "    y = tst_model(x)\n",
    "    test_eq(y, h.stored)\n",
    "    assert not h.stored.requires_grad\n",
    "    \n",
    "with hook_output(tst_model, grad=True) as h:\n",
    "    y = tst_model(x)\n",
    "    loss = y.pow(2).mean()\n",
    "    loss.backward()\n",
    "    test_close(2*y / y.numel(), h.stored[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#cuda\n",
    "with hook_output(tst_model, cpu=True) as h:\n",
    "    y = tst_model.cuda()(x.cuda())\n",
    "    test_eq(h.stored.device, torch.device('cpu'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hooks -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@docs\n",
    "class Hooks():\n",
    "    \"Create several hooks on the modules in `ms` with `hook_func`.\"\n",
    "    def __init__(self, ms, hook_func, is_forward=True, detach=True, cpu=False):\n",
    "        # one `Hook` per module, all sharing the same function and options\n",
    "        self.hooks = [Hook(mod, hook_func, is_forward=is_forward, detach=detach, cpu=cpu) for mod in ms]\n",
    "\n",
    "    def __len__(self):          return len(self.hooks)\n",
    "    def __iter__(self):         return iter(self.hooks)\n",
    "    def __getitem__(self,i):    return self.hooks[i]\n",
    "    @property\n",
    "    def stored(self):           return L(hook.stored for hook in self)\n",
    "\n",
    "    def remove(self):\n",
    "        \"Remove the hooks from the model.\"\n",
    "        for hook in self.hooks: hook.remove()\n",
    "\n",
    "    def __enter__(self, *args): return self\n",
    "    def __exit__ (self, *args): self.remove()\n",
    "\n",
    "    _docs = dict(stored = \"The states saved in each hook.\",\n",
    "                 __enter__=\"Register the hooks\",\n",
    "                 __exit__=\"Remove the hooks\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]\n",
    "tst_model = nn.Sequential(*layers)\n",
    "hooks = Hooks(tst_model, lambda m,i,o: o)\n",
    "y = tst_model(x)\n",
    "test_eq(hooks.stored[0], layers[0](x))\n",
    "test_eq(hooks.stored[1], F.relu(layers[0](x)))\n",
    "test_eq(hooks.stored[2], y)\n",
    "hooks.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Hooks.stored\" class=\"doc_header\"><code>Hooks.stored</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "The states saved in each hook."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Hooks.stored, name='Hooks.stored')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Hooks.remove\" class=\"doc_header\"><code>Hooks.remove</code><a href=\"__main__.py#L14\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Hooks.remove</code>()\n",
       "\n",
       "Remove the hooks from the model."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Hooks.remove)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Context Manager"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Like `Hook`, you can use `Hooks` as a context manager."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Hooks.__enter__\" class=\"doc_header\"><code>Hooks.__enter__</code><a href=\"__main__.py#L18\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Hooks.__enter__</code>(**\\*`args`**)\n",
       "\n",
       "Register the hooks"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Hooks.__enter__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"Hooks.__exit__\" class=\"doc_header\"><code>Hooks.__exit__</code><a href=\"__main__.py#L19\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>Hooks.__exit__</code>(**\\*`args`**)\n",
       "\n",
       "Remove the hooks"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(Hooks.__exit__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]\n",
    "tst_model = nn.Sequential(*layers)\n",
    "with Hooks(layers, lambda m,i,o: o) as h:\n",
    "    y = tst_model(x)\n",
    "    test_eq(h.stored[0], layers[0](x))\n",
    "    test_eq(h.stored[1], F.relu(layers[0](x)))\n",
    "    test_eq(h.stored[2], y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def hook_outputs(modules, detach=True, cpu=False, grad=False):\n",
    "    \"Return `Hooks` that store activations of all `modules` in `self.stored`\"\n",
    "    # backward hooks are registered when gradients are requested, forward hooks otherwise\n",
    "    return Hooks(modules, _hook_inner, is_forward=not grad, detach=detach, cpu=cpu)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The activations stored are the gradients if `grad=True`, otherwise the output of `modules`. If `detach=True` they are detached from their history, and if `cpu=True`, they're put on the CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]\n",
    "tst_model = nn.Sequential(*layers)\n",
    "x = torch.randn(4,5)\n",
    "with hook_outputs(layers) as h:\n",
    "    y = tst_model(x)\n",
    "    test_eq(h.stored[0], layers[0](x))\n",
    "    test_eq(h.stored[1], F.relu(layers[0](x)))\n",
    "    test_eq(h.stored[2], y)\n",
    "    for s in h.stored: assert not s.requires_grad\n",
    "    \n",
    "with hook_outputs(layers, grad=True) as h:\n",
    "    y = tst_model(x)\n",
    "    loss = y.pow(2).mean()\n",
    "    loss.backward()\n",
    "    g = 2*y / y.numel()\n",
    "    test_close(g, h.stored[2][0])\n",
    "    g = g @ layers[2].weight.data\n",
    "    test_close(g, h.stored[1][0])\n",
    "    g = g * (layers[0](x) > 0).float()\n",
    "    test_close(g, h.stored[0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#cuda\n",
    "with hook_outputs(tst_model, cpu=True) as h:\n",
    "    y = tst_model.cuda()(x.cuda())\n",
    "    for s in h.stored: test_eq(s.device, torch.device('cpu'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def dummy_eval(m, size=(64,64)):\n",
    "    \"Evaluate `m` on a dummy input of a certain `size`\"\n",
    "    # a single-item batch with as many channels as `in_channels` reports for `m`\n",
    "    shape = (1, in_channels(m), *size)\n",
    "    # `new` is called on one of the model's parameters to allocate the input\n",
    "    x = one_param(m).new(*shape)\n",
    "    x.requires_grad_(False).uniform_(-1.,1.)\n",
    "    m.eval()\n",
    "    with torch.no_grad(): return m(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def model_sizes(m, size=(64,64)):\n",
    "    \"Pass a dummy input through the model `m` to get the various sizes of activations.\"\n",
    "    with hook_outputs(m) as hooks:\n",
    "        # the returned value is not needed, only what the hooks stored during the pass\n",
    "        dummy_eval(m, size=size)\n",
    "        return [hook.stored.shape for hook in hooks]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = nn.Sequential(ConvLayer(3, 16), ConvLayer(16, 32, stride=2), ConvLayer(32, 32))\n",
    "test_eq(model_sizes(m), [[1, 16, 64, 64], [1, 32, 32, 32], [1, 32, 32, 32]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def num_features_model(m):\n",
    "    \"Return the number of output features for `m`.\"\n",
    "    in_channels(m) # evaluated once before any attempt; its result is not used here\n",
    "    # Try growing square inputs (32 up to 2048) in case the model requires a big input size.\n",
    "    for sz in (32, 64, 128, 256, 512, 1024, 2048):\n",
    "        try:\n",
    "            return model_sizes(m, (sz,sz))[-1][1]\n",
    "        except Exception as e:\n",
    "            # only give up once the largest size has failed too\n",
    "            if sz == 2048: raise e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = nn.Sequential(nn.Conv2d(5,4,3), nn.Conv2d(4,3,3))\n",
    "test_eq(num_features_model(m), 3)\n",
    "m = nn.Sequential(ConvLayer(3, 16), ConvLayer(16, 32, stride=2), ConvLayer(32, 32))\n",
    "test_eq(num_features_model(m), 32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HookCallback -"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make hooks easy to use, we wrapped a version in a Callback where you just have to implement a `hook` function (plus any element you might need)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def has_params(m):\n",
    "    \"Check if `m` has at least one parameter\"\n",
    "    return len(list(m.parameters())) > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert has_params(nn.Linear(3,4))\n",
    "assert has_params(nn.LSTM(4,5,2))\n",
    "assert not has_params(nn.ReLU())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@funcs_kwargs\n",
    "class HookCallback(Callback):\n",
    "    \"`Callback` that can be used to register hooks on `modules`\"\n",
    "    _methods = [\"hook\"]\n",
    "    hook = noops\n",
    "    def __init__(self, modules=None, every=None, remove_end=True, is_forward=True, detach=True, cpu=True, include_paramless=False , **kwargs):\n",
    "        store_attr('modules,every,remove_end,is_forward,detach,cpu, include_paramless')\n",
    "        assert not kwargs\n",
    "\n",
    "    def before_fit(self):\n",
    "        \"Register the `Hooks` on `self.modules`.\"\n",
    "        if self.modules is None:\n",
    "            # default to every flattened layer, skipping parameterless ones unless asked to keep them\n",
    "            self.modules = [m for m in flatten_model(self.model) if self.include_paramless or has_params(m)]\n",
    "        # without `every`, the hooks are registered once for the whole fit\n",
    "        if self.every is None: self._register()\n",
    "\n",
    "    def before_batch(self):\n",
    "        # with `every` set, hooks are only registered on training iterations that are multiples of it\n",
    "        if self.every is not None and self.training and self.train_iter%self.every==0: self._register()\n",
    "\n",
    "    def after_batch(self):\n",
    "        # same condition as `before_batch`, evaluated after the batch, to take the hooks down again\n",
    "        if self.every is not None and self.training and self.train_iter%self.every==0: self._remove()\n",
    "\n",
    "    def after_fit(self):\n",
    "        \"Remove the `Hooks`.\"\n",
    "        if self.remove_end: self._remove()\n",
    "\n",
    "    def _register(self):\n",
    "        self.hooks = Hooks(self.modules, self.hook, self.is_forward, self.detach, self.cpu)\n",
    "\n",
    "    def _remove(self):\n",
    "        # nothing to do when no hooks were ever registered (or the collection is empty)\n",
    "        hooks = getattr(self, 'hooks', None)\n",
    "        if hooks: hooks.remove()\n",
    "\n",
    "    def __del__(self): self._remove()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can either subclass and implement a `hook` function (along with any event you want) or pass a `hook` function when initializing. Such a function needs to take three arguments: a layer, input and output (for a backward hook, input means gradient with respect to the inputs, and output means gradient with respect to the output) and can either modify them or update the state according to them.\n",
    "\n",
    "If not provided, `modules` will default to the layers of `self.model` that have at least one parameter (to also include the parameterless layers of `self.model`, e.g. `ReLU`, `Flatten` etc., set `include_paramless=True`).\n",
    "Depending on `remove_end`, the hooks will be properly removed at the end of training (or in case of error). `is_forward`, `detach` and `cpu` are passed to `Hooks`.\n",
    "\n",
    "The function called at each forward (or backward) pass is `self.hook` and must be implemented when subclassing this callback."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 15.194612503051758, 12.082895278930664, '00:00']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/wgilliam/development/projects/fastai/nbs/fastai/callback/core.py:51: UserWarning: You are shadowing an attribute (modules) that exists in the learner. Use `self.learn.modules` to avoid this\n",
      "  warn(f\"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this\")\n"
     ]
    }
   ],
   "source": [
    "class TstCallback(HookCallback):\n",
    "    def hook(self, m, i, o): return o\n",
    "    def after_batch(self): test_eq(self.hooks.stored[0], self.pred)\n",
    "        \n",
    "learn = synth_learner(n_trn=5, cbs = TstCallback())\n",
    "learn.fit(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1.9420239925384521, 1.7295372486114502, '00:00']\n"
     ]
    }
   ],
   "source": [
    "class TstCallback(HookCallback):\n",
    "    def __init__(self, modules=None, remove_end=True, detach=True, cpu=False):\n",
    "        super().__init__(modules, None, remove_end, False, detach, cpu)\n",
    "    def hook(self, m, i, o): return o\n",
    "    def after_batch(self):\n",
    "        if self.training:\n",
    "            test_eq(self.hooks.stored[0][0], 2*(self.pred-self.y)/self.pred.shape[0])\n",
    "        \n",
    "learn = synth_learner(n_trn=5, cbs = TstCallback())\n",
    "learn.fit(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"HookCallback.before_fit\" class=\"doc_header\"><code>HookCallback.before_fit</code><a href=\"__main__.py#L11\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>HookCallback.before_fit</code>()\n",
       "\n",
       "Register the [`Hooks`](/callback.hook.html#Hooks) on `self.modules`."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(HookCallback.before_fit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"HookCallback.after_fit\" class=\"doc_header\"><code>HookCallback.after_fit</code><a href=\"__main__.py#L24\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>HookCallback.after_fit</code>()\n",
       "\n",
       "Remove the [`Hooks`](/callback.hook.html#Hooks)."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(HookCallback.after_fit)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def total_params(m):\n",
    "    \"Give the number of parameters of a module and if it's trainable or not\"\n",
    "    params = sum([p.numel() for p in m.parameters()])\n",
    "    trains = [p.requires_grad for p in m.parameters()]\n",
    "    return params, (False if len(trains)==0 else trains[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(total_params(nn.Linear(10,32)), (32*10+32,True))\n",
    "test_eq(total_params(nn.Linear(10,32, bias=False)), (32*10,True))\n",
    "test_eq(total_params(nn.BatchNorm2d(20)), (20*2, True))\n",
    "test_eq(total_params(nn.BatchNorm2d(20, affine=False)), (0,False))\n",
    "test_eq(total_params(nn.Conv2d(16, 32, 3)), (16*32*3*3 + 32, True))\n",
    "test_eq(total_params(nn.Conv2d(16, 32, 3, bias=False)), (16*32*3*3, True))\n",
    "#First ih layer 20--10, all else 10--10. *4 for the four gates\n",
    "test_eq(total_params(nn.LSTM(20, 10, 2)), (4 * (20*10 + 10) + 3 * 4 * (10*10 + 10), True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def layer_info(learn, *xb):\n",
    "    \"Return layer infos of `model` on `xb` (only support batch first inputs)\"\n",
    "    def _track(m, i, o):\n",
    "        # `zip(i, o)` iterates over `o` itself, so each input (minus its first dim) is compared with one item of `o`\n",
    "        same = any(isinstance(inp, torch.Tensor) and inp.shape[1:] == out.shape for inp,out in zip(i, o))\n",
    "        shape = apply(lambda t: t.shape, o)\n",
    "        # parameter count and trainable flag are only reported for modules exposing a `weight`; others get ''\n",
    "        params, trainable = total_params(m) if hasattr(m, 'weight') else ('', '')\n",
    "        return (type(m).__name__, params, trainable, shape, same)\n",
    "\n",
    "    with Hooks(flatten_model(learn.model), _track) as h:\n",
    "        # only the first item of each input is sent through the model\n",
    "        batch = apply(lambda o:o[:1], xb)\n",
    "        train_only_cbs = [cb for cb in learn.cbs if hasattr(cb, '_only_train_loop')]\n",
    "        with learn.removed_cbs(train_only_cbs), learn.no_logging(), learn as l:\n",
    "            # predictions are discarded; what matters is what the hooks recorded\n",
    "            l.get_preds(dl=[batch], inner=True, reorder=False)\n",
    "        return h.stored"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output of `_track` is expected to be a `tuple` of the module name, the number of parameters, whether it is trainable, the shape of the layer's output, and whether or not that shape is the same as the input's (`same`). There are three potential groups that can show:\n",
    "* A non-activation layer (Linear, Conv, etc)\n",
    "* An activation layer\n",
    "* A pooling layer\n",
    "\n",
    "Depending on the group, only part of the output is really returned; the rest is `''`. For non-activation layers everything is returned. Activation layers only return a name, the shape and `same` (which is `True` when the activation keeps the shape of its input, as `ReLU` does). Pooling layers will return the name, the new shape, and `False` for `same`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _m(): return nn.Sequential(nn.Linear(1,50), nn.ReLU(), nn.BatchNorm1d(50), nn.Linear(50, 1))\n",
    "sample_input = torch.randn((16, 1))\n",
    "test_eq(layer_info(synth_learner(model=_m()), sample_input), [\n",
    "    ('Linear', 100, True, [1, 50], False),\n",
    "    ('ReLU', '', '', [1,50], True),\n",
    "    ('BatchNorm1d', 100, True, [1, 50], True),\n",
    "    ('Linear', 51, True, [1, 1], False)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "# Test for Flatten\n",
    "def _tst_m(): return nn.Sequential(\n",
    "    nn.Conv2d(1, 2, kernel_size=3, padding=1, stride=2),\n",
    "    nn.ReLU(),\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(8,50), \n",
    "    nn.ReLU(), \n",
    "    nn.BatchNorm1d(50), \n",
    "    nn.Linear(50, 1)\n",
    ")\n",
    "                                                              \n",
    "sample_input = torch.randn((1,1,4,4))\n",
    "test_eq(layer_info(synth_learner(model=_tst_m()), sample_input), [\n",
    "    ('Conv2d', 20, True, [1, 2, 2, 2], False),\n",
    "    ('ReLU', '', '', [1, 2, 2, 2], True),\n",
    "    ('Flatten', '', '', [1, 8], False),\n",
    "    ('Linear', 450, True, [1, 50], False),\n",
    "    ('ReLU', '', '', [1,50], True),\n",
    "    ('BatchNorm1d', 100, True, [1, 50], True),\n",
    "    ('Linear', 51, True, [1, 1], False)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "# Test for multiple inputs model\n",
    "class _2InpModel(Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.seq = nn.Sequential(nn.Linear(2,50), nn.ReLU(), nn.BatchNorm1d(50), nn.Linear(50, 1))\n",
    "    def forward(self, *inps):\n",
    "        outputs = torch.cat(inps, dim=-1)\n",
    "        return self.seq(outputs)\n",
    "\n",
    "sample_inputs = (torch.randn(16, 1), torch.randn(16, 1))\n",
    "learn = synth_learner(model=_2InpModel())\n",
    "learn.dls.n_inp = 2\n",
    "test_eq(layer_info(learn, *sample_inputs), [\n",
    "    ('Linear', 150, True, [1, 50], False),\n",
    "    ('ReLU', '', '', [1,50], True),\n",
    "    ('BatchNorm1d', 100, True, [1, 50], True),\n",
    "    ('Linear', 51, True, [1, 1], False)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _get_shapes(o, bs): \n",
    "    inp = o[first(o)] if (isinstance(o, dict)) else o\n",
    "    return ' x '.join([str(bs)] + [str(t) for t in inp[1:]])\n",
    "\n",
    "def _print_shapes(o, bs):\n",
    "    if isinstance(o, torch.Size): return _get_shapes(o, bs)\n",
    "    elif isinstance(o, tuple): return _get_shapes(o[0], bs)\n",
    "    else: return str([_print_shapes(x, bs) for x in o])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def module_summary(learn, *xb):\n",
    "    \"Print a summary of `model` using `xb`\"\n",
    "    #Individual parameters wrapped in ParameterModule aren't called through the hooks in `layer_info`,\n",
    "    #  thus are not counted inside the summary\n",
    "    #TODO: find a way to have them counted in param number somehow\n",
    "    infos = layer_info(learn, *xb)\n",
    "    n,bs = 76,find_bs(xb)\n",
    "    inp_sz = _print_shapes(apply(lambda x:x.shape, xb), bs)\n",
    "    res = f\"{type(learn.model).__name__} (Input shape: {inp_sz})\\n\"\n",
    "    res += \"=\" * n + \"\\n\"\n",
    "    res += f\"{'Layer (type)':<20} {'Output Shape':<20} {'Param #':<10} {'Trainable':<10}\\n\"\n",
    "    res += \"=\" * n\n",
    "    ps,trn_ps,j = 0,0,0\n",
    "    infos = [o for o in infos if o is not None] #see comment in previous cell\n",
    "    prev_sz = None\n",
    "    for typ,np,trn,sz,chnged in infos:\n",
    "        if sz is None: continue\n",
    "        if j == 0:\n",
    "            res += f'\\n{\"\":<20} {_print_shapes(sz, bs)[:19]:<20}' # to avoid a double line at the top\n",
    "        if not chnged and not prev_sz == sz and j > 0: res += \"\\n\" + \"_\" * n + \"\\n\" + f'{\"\":<20} {_print_shapes(sz, bs)[:19]:<20}'\n",
    "        j = 1\n",
    "        res += f\"\\n{typ:<20} {'':<20} {np:<10} {str(trn):<10}\"\n",
    "        if np != '':\n",
    "            ps += np\n",
    "            if trn: trn_ps += np\n",
    "        prev_sz = sz\n",
    "    res += \"\\n\" + \"_\" * n + \"\\n\"\n",
    "    res += f\"\\nTotal params: {ps:,}\\n\"\n",
    "    res += f\"Total trainable params: {trn_ps:,}\\n\"\n",
    "    res += f\"Total non-trainable params: {ps - trn_ps:,}\\n\\n\"\n",
    "    return PrettyString(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "def summary(self:Learner):\n",
    "    \"Print a summary of the model, optimizer and loss function.\"\n",
    "    xb = self.dls.train.one_batch()[:getattr(self.dls.train, \"n_inp\", 1)]\n",
    "    res = module_summary(self, *xb)\n",
    "    res += f\"Optimizer used: {self.opt_func}\\nLoss function: {self.loss_func}\\n\\n\"\n",
    "    if self.opt is not None:\n",
    "        res += f\"Model \" + (\"unfrozen\\n\\n\" if self.opt.frozen_idx==0 else f\"frozen up to parameter group #{self.opt.frozen_idx}\\n\\n\")\n",
    "    res += \"Callbacks:\\n\" + '\\n'.join(f\"  - {cb}\" for cb in self.cbs.sorted('order'))\n",
    "    return PrettyString(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential (Input shape: 16 x 1)\n",
       "============================================================================\n",
       "Layer (type)         Output Shape         Param #    Trainable \n",
       "============================================================================\n",
       "                     16 x 50             \n",
       "Linear                                    100        True      \n",
       "ReLU                                                           \n",
       "BatchNorm1d                               100        True      \n",
       "____________________________________________________________________________\n",
       "                     16 x 1              \n",
       "Linear                                    51         True      \n",
       "____________________________________________________________________________\n",
       "\n",
       "Total params: 251\n",
       "Total trainable params: 251\n",
       "Total non-trainable params: 0\n",
       "\n",
       "Optimizer used: functools.partial(<function SGD at 0x7f1c21a2f8b0>, mom=0.9)\n",
       "Loss function: FlattenedLoss of MSELoss()\n",
       "\n",
       "Callbacks:\n",
       "  - TrainEvalCallback\n",
       "  - Recorder"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn = synth_learner(model=_m())\n",
    "learn.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential (Input shape: 16 x 1)\n",
       "============================================================================\n",
       "Layer (type)         Output Shape         Param #    Trainable \n",
       "============================================================================\n",
       "                     16 x 50             \n",
       "Linear                                    100        True      \n",
       "ReLU                                                           \n",
       "BatchNorm1d                               100        True      \n",
       "____________________________________________________________________________\n",
       "                     16 x 1              \n",
       "Linear                                    51         True      \n",
       "____________________________________________________________________________\n",
       "\n",
       "Total params: 251\n",
       "Total trainable params: 251\n",
       "Total non-trainable params: 0\n",
       "\n",
       "Optimizer used: functools.partial(<function SGD at 0x7f1c21a2f8b0>, mom=0.9)\n",
       "Loss function: FlattenedLoss of MSELoss()\n",
       "\n",
       "Callbacks:\n",
       "  - TrainEvalCallback\n",
       "  - Recorder"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#hide\n",
    "#cuda\n",
    "learn = synth_learner(model=_m(), cuda=True)\n",
    "learn.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_NOutModel (Input shape: 16 x 1)\n",
       "============================================================================\n",
       "Layer (type)         Output Shape         Param #    Trainable \n",
       "============================================================================\n",
       "                     16 x 6              \n",
       "Linear                                    36         True      \n",
       "____________________________________________________________________________\n",
       "\n",
       "Total params: 36\n",
       "Total trainable params: 36\n",
       "Total non-trainable params: 0\n",
       "\n",
       "Optimizer used: functools.partial(<function SGD at 0x7f1c21a2f8b0>, mom=0.9)\n",
       "Loss function: FlattenedLoss of MSELoss()\n",
       "\n",
       "Callbacks:\n",
       "  - TrainEvalCallback\n",
       "  - Recorder"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#hide\n",
    "# Test for multiple output\n",
    "class _NOutModel(Module):\n",
    "    def __init__(self): self.lin = nn.Linear(5, 6)\n",
    "    def forward(self, x1):\n",
    "        x = torch.randn((10, 5))\n",
    "        return x,self.lin(x)\n",
    "\n",
    "learn = synth_learner(model = _NOutModel())\n",
    "learn.summary() # Output Shape should be (50, 16, 256), (1, 16, 256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential (Input shape: 16 x 4)\n",
       "============================================================================\n",
       "Layer (type)         Output Shape         Param #    Trainable \n",
       "============================================================================\n",
       "                     16 x 2              \n",
       "Linear                                    10         True      \n",
       "ReLU                                                           \n",
       "____________________________________________________________________________\n",
       "                     16 x 1              \n",
       "Linear                                    3          True      \n",
       "____________________________________________________________________________\n",
       "\n",
       "Total params: 13\n",
       "Total trainable params: 13\n",
       "Total non-trainable params: 0\n",
       "\n",
       "Optimizer used: <function Adam at 0x7f1c21a2fc10>\n",
       "Loss function: <function l1_loss at 0x7f1c34174550>\n",
       "\n",
       "Callbacks:\n",
       "  - TrainEvalCallback\n",
       "  - Recorder"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#hide\n",
    "# Test for the case (as in Book) when learn.dls.train_ds is a list not fastai.data.core.Datasets\n",
    "train_x = torch.rand((100, 4))\n",
    "train_y = torch.rand((100, 1))\n",
    "\n",
    "valid_x = torch.rand((100, 4))\n",
    "valid_y = torch.rand((100,1))\n",
    "\n",
    "dset = list(zip(train_x,train_y))\n",
    "valid_dset = list(zip(valid_x,valid_y))\n",
    "\n",
    "dl = DataLoader(dset, batch_size=16)\n",
    "val_dl = DataLoader(valid_dset, batch_size=16)\n",
    "\n",
    "dls = DataLoaders(dl, val_dl)\n",
    "\n",
    "simple_net = nn.Sequential(\n",
    "    nn.Linear(4, 2),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(2,1)\n",
    ")\n",
    "learn = Learner(dls, simple_net, loss_func=F.l1_loss)\n",
    "learn.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Activation graphs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an example of a `HookCallback`, that stores the mean, stds and histograms of activations that go through the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#exports\n",
    "@delegates()\n",
    "class ActivationStats(HookCallback):\n",
    "    \"Callback that record the mean and std of activations.\"\n",
    "    order=-20\n",
    "    def __init__(self, with_hist=False, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.with_hist = with_hist\n",
    "\n",
    "    def before_fit(self):\n",
    "        \"Initialize stats.\"\n",
    "        super().before_fit()\n",
    "        self.stats = L()\n",
    "\n",
    "    def hook(self, m, i, o):\n",
    "        if isinstance(o, tuple): return self.hook_multi_ouput(o)\n",
    "        o = o.float()\n",
    "        res = {'mean': o.mean().item(), 'std': o.std().item(),\n",
    "               'near_zero': (o<=0.05).long().sum().item()/o.numel()}\n",
    "        if self.with_hist: res['hist'] = o.histc(40,0,10)\n",
    "        return res\n",
    "    \n",
    "    def hook_multi_ouput(self,o_tuple):\n",
    "        \"For outputs of RNN which are [nested] tuples of tensors\"\n",
    "        res = []\n",
    "        for o in self._flatten_tuple(o_tuple):\n",
    "            if not(isinstance(o, Tensor)): continue\n",
    "            res.append(self.hook(None, None, o))\n",
    "        return res\n",
    "\n",
    "    def _flatten_tuple(self, o_tuple):\n",
    "        \"Recursively flatten a [nested] tuple\"\n",
    "        res = []\n",
    "        for it in o_tuple:\n",
    "            if isinstance(it, tuple): res += self._flatten_tuple(it)\n",
    "            else: res += [it]\n",
    "        return tuple(res)\n",
    "\n",
    "    def after_batch(self):\n",
    "        \"Take the stored results and puts it in `self.stats`\"\n",
    "        if self.training and (self.every is None or self.train_iter%self.every == 0):\n",
    "            self.stats.append(self.hooks.stored)\n",
    "        super().after_batch()\n",
    "\n",
    "    def layer_stats(self, idx):\n",
    "        lstats = self.stats.itemgot(idx)\n",
    "        return L(lstats.itemgot(o) for o in ('mean','std','near_zero'))\n",
    "\n",
    "    def hist(self, idx):\n",
    "        res = self.stats.itemgot(idx).itemgot('hist')\n",
    "        return torch.stack(tuple(res)).t().float().log1p()\n",
    "\n",
    "    def color_dim(self, idx, figsize=(10,5), ax=None):\n",
    "        \"The 'colorful dimension' plot\"\n",
    "        res = self.hist(idx)\n",
    "        if ax is None: ax = subplots(figsize=figsize)[1][0]\n",
    "        ax.imshow(res, origin='lower')\n",
    "        ax.axis('off')\n",
    "\n",
    "    def plot_layer_stats(self, idx):\n",
    "        _,axs = subplots(1, 3, figsize=(12,3))\n",
    "        for o,ax,title in zip(self.layer_stats(idx),axs,('mean','std','% near zero')):\n",
    "            ax.plot(o)\n",
    "            ax.set_title(title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 7.0911030769348145, 7.362981796264648, '00:00']\n"
     ]
    }
   ],
   "source": [
    "learn = synth_learner(n_trn=5, cbs = ActivationStats(every=4))\n",
    "learn.fit(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#2) [[{'mean': 1.1225931644439697, 'std': 0.06165130063891411, 'near_zero': 0.0}],[{'mean': 1.1687103509902954, 'std': 0.09164517372846603, 'near_zero': 0.0}]]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.activation_stats.stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first line contains the means of the outputs of the model for each batch in the training set, the second line their standard deviations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 10.788463592529297, 9.021217346191406, '00:00']\n",
      "[0, 8.412554740905762, 6.59454870223999, '00:00']\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "def test_activation_stats_include_paramless(include_paramless=False):\n",
    "    \"create a learner, fit, then check number of layers\"\n",
    "    modl = nn.Sequential(nn.Linear(1,50), nn.ReLU(), nn.BatchNorm1d(50), nn.Linear(50, 1), nn.Flatten())\n",
    "\n",
    "    learn = synth_learner(model=modl, cbs=ActivationStats(every=4, include_paramless=include_paramless))\n",
    "    learn.fit(1)\n",
    "    \n",
    "    expected_stats_len = 3  \n",
    "    if include_paramless: expected_stats_len = 5 # includes Relu & Flatten\n",
    "    test_eq(expected_stats_len, len(learn.activation_stats.modules))\n",
    "\n",
    "test_activation_stats_include_paramless(include_paramless=True)\n",
    "test_activation_stats_include_paramless(include_paramless=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 12.904561042785645, 10.325477600097656, '00:00']\n",
      "[0, 17.693506240844727, 14.081195831298828, '00:00']\n",
      "[0, 26.353797912597656, 18.434146881103516, '00:00']\n",
      "[0, 13.517264366149902, 11.664457321166992, '00:00']\n",
      "[0, 6.353066444396973, 7.012430191040039, '00:00']\n",
      "[0, 26.239334106445312, 19.789249420166016, '00:00']\n"
     ]
    }
   ],
   "source": [
    "def test_every(n_tr, every):\n",
    "    \"create a learner, fit, then check number of stats collected\"\n",
    "    learn = synth_learner(n_trn=n_tr, cbs=ActivationStats(every=every))\n",
    "    learn.fit(1)\n",
    "    expected_stats_len = math.ceil(n_tr / every)\n",
    "    test_eq(expected_stats_len, len(learn.activation_stats.stats))\n",
    "    \n",
    "for n_tr in [11, 12, 13]:\n",
    "    test_every(n_tr, 4)\n",
    "    test_every(n_tr, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 12.546510696411133, 13.665218353271484, '00:00']\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "class TstCallback(HookCallback):\n",
    "    def hook(self, m, i, o): return o\n",
    "    def before_fit(self):\n",
    "        super().before_fit()\n",
    "        self.means,self.stds = [],[]\n",
    "    \n",
    "    def after_batch(self):\n",
    "        if self.training:\n",
    "            self.means.append(self.hooks.stored[0].mean().item())\n",
    "            self.stds.append (self.hooks.stored[0].std() .item())\n",
    "\n",
    "learn = synth_learner(n_trn=5, cbs = [TstCallback(), ActivationStats()])\n",
    "learn.fit(1)\n",
    "test_eq(learn.activation_stats.stats.itemgot(0).itemgot(\"mean\"), learn.tst.means)\n",
    "test_eq(learn.activation_stats.stats.itemgot(0).itemgot(\"std\"),  learn.tst.stds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.image_sequence.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.azureml.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted app_examples.ipynb.\n",
      "Converted camvid.ipynb.\n",
      "Converted migrating_catalyst.ipynb.\n",
      "Converted migrating_ignite.ipynb.\n",
      "Converted migrating_lightning.ipynb.\n",
      "Converted migrating_pytorch.ipynb.\n",
      "Converted migrating_pytorch_verbose.ipynb.\n",
      "Converted ulmfit.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
